{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notebook to try out CALDERA decomposition on a random matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import sys\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "src_dir = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
    "\n",
    "if src_dir not in sys.path:\n",
    "    sys.path.append(src_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.caldera.decomposition.dataclasses import CalderaParams\n",
    "from src.caldera.utils.quantization import QuantizerFactory\n",
    "from src.caldera.decomposition.alg import caldera"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "quant_factory_Q = QuantizerFactory(method=\"uniform\", block_size=64)\n",
    "quant_factor_LR = QuantizerFactory(method=\"uniform\", block_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "quant_params = CalderaParams(\n",
    "    compute_quantized_component=True,\n",
    "    compute_low_rank_factors=True,\n",
    "    Q_bits=4,\n",
    "    L_bits=4,\n",
    "    R_bits=4,\n",
    "    rank=16,\n",
    "    iters=20,\n",
    "    lplr_iters=5,\n",
    "    activation_aware_Q=False,\n",
    "    activation_aware_LR=True,\n",
    "    lattice_quant_Q=False,\n",
    "    lattice_quant_LR=False,\n",
    "    update_order=[\"Q\", \"LR\"],\n",
    "    quant_factory_Q=quant_factory_Q,\n",
    "    quant_factory_LR=quant_factor_LR,\n",
    "    rand_svd=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(42)\n",
    "\n",
    "W = torch.rand(1024, 1024)\n",
    "X = torch.randn(1024, 2048)\n",
    "H = torch.matmul(X, X.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decompose W using H and the settings in quant_params. The result exposes\n",
    "# .Q, .L and .R, whose shapes are printed in the next cell.\n",
    "caldera_decomposition = caldera(\n",
    "    W=W,\n",
    "    H=H,\n",
    "    quant_params=quant_params,\n",
    "    device=\"cpu\",\n",
    "    scale_W=True,\n",
    "    use_tqdm=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the shape of each part of the decomposition, one per line.\n",
    "print(\n",
    "    f\"caldera_decomposition.Q.shape: {caldera_decomposition.Q.shape}\",\n",
    "    f\"caldera_decomposition.L.shape: {caldera_decomposition.L.shape}\",\n",
    "    f\"caldera_decomposition.R.shape: {caldera_decomposition.R.shape}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.11 (venv)",
   "language": "python",
   "name": "venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
